# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import threading
from unittest import SkipTest

import jax
import jax._src.test_util as jtu
import numpy as np
from absl.testing import absltest
from jax import config, lax
from jax import numpy as jnp
from jax._src import core, xla_bridge
from jax._src.lib import xla_client
from jax.experimental import host_callback as hcb

from ..utils.timer_wrapper import jax_op_timer, partial_timed

config.parse_flags_with_absl()


class InfeedTest(jtu.JaxTestCase):
    def setUp(self):
        # if xla_bridge.using_pjrt_c_api():
        #     raise SkipTest("infeed not implemented in PJRT C API")
        super().setUp()

    @jax.numpy_rank_promotion(
        "allow"
    )  # Test explicitly exercises implicit rank promotion.
    def testInfeed(self):
        @jax.jit
        def f(x):
            token = lax.create_token(x)
            timer = jax_op_timer()
            with timer:
                (y,), token = lax.infeed(
                    token, shape=(core.ShapedArray((3, 4), jnp.float32),)
                )
                timer.gen.send((y, token))

            timer = jax_op_timer()
            with timer:
                (z,), test = lax.infeed(
                    token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),)
                )
                timer.gen.send((z, test))
            return x + y + z

        x = np.float32(1.5)
        y = np.reshape(
            np.arange(12, dtype=np.float32), (3, 4)
        )  # self.rng().randn(3, 4).astype(np.float32)
        z = self.rng().randn(3, 1, 1).astype(np.float32)
        device = jax.local_devices()[0]
        device.transfer_to_infeed((y,))
        device.transfer_to_infeed((z,))
        self.assertAllClose(f(x), x + y + z)

    def testInfeedPytree(self):
        x = np.float32(1.5)
        y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
        to_infeed = dict(a=x, b=y)
        to_infeed_shape = dict(
            a=core.ShapedArray((), dtype=np.float32),
            b=core.ShapedArray((3, 4), dtype=np.int16),
        )

        @jax.jit
        def f(x):
            token = lax.create_token(x)
            res, token = lax.infeed(token, shape=to_infeed_shape)
            return res

        device = jax.local_devices()[0]
        # We must transfer the flattened data, as a tuple!!!
        flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
        device.transfer_to_infeed(tuple(flat_to_infeed))
        self.assertAllClose(f(x), to_infeed)

    @jax.numpy_rank_promotion(
        "allow"
    )  # Test explicitly exercises implicit rank promotion.
    def testInfeedThenOutfeed(self):
        hcb.stop_outfeed_receiver()

        @jax.jit
        def f(x):
            token = lax.create_token(x)
            y, token = lax.infeed(
                token, shape=core.ShapedArray((3, 4), jnp.float32))
            token = lax.outfeed(token, y + np.float32(1))
            return x - 1

        x = np.float32(7.5)
        y = self.rng().randn(3, 4).astype(np.float32)
        execution = threading.Thread(target=lambda: f(x))
        execution.start()
        device = jax.local_devices()[0]
        device.transfer_to_infeed((y,))
        (out,) = device.transfer_from_outfeed(
            xla_client.shape_from_pyval(
                (y,)).with_major_to_minor_layout_if_absent()
        )
        execution.join()
        self.assertAllClose(out, y + np.float32(1))

    def testInfeedThenOutfeedInALoop(self):
        hcb.stop_outfeed_receiver()

        def doubler(_, token):
            timer = jax_op_timer()
            with timer:
                y, token = lax.infeed(
                    token, shape=core.ShapedArray((3, 4), jnp.float32))
                timer.gen.send((y, token))
            return lax.outfeed(token, y * np.float32(2))

        @jax.jit
        def f(n):
            token = lax.create_token(n)
            token = lax.fori_loop(0, n, doubler, token)
            return n

        device = jax.local_devices()[0]
        n = 10
        execution = threading.Thread(target=lambda: f(n))
        execution.start()
        for _ in range(n):
            x = self.rng().randn(3, 4).astype(np.float32)
            device.transfer_to_infeed((x,))
            (y,) = device.transfer_from_outfeed(
                xla_client.shape_from_pyval(
                    (x,)).with_major_to_minor_layout_if_absent()
            )
            self.assertAllClose(y, x * np.float32(2))
        execution.join()


if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())
